import numpy as np
from numba import njit

@njit
def build_tri(x, T):
    """Return the T x T lower-triangular Toeplitz matrix built from ``x``.

    Entry ``(i, j)`` equals ``x[i - j]`` for ``i >= j`` and zero above the
    diagonal.

    Raises ValueError if ``x`` has fewer than ``T`` entries. Numba does not
    bounds-check array indexing by default, so without this guard a short
    ``x`` would be read past its end silently.
    """
    if len(x) < T:
        raise ValueError("build_tri: x must have at least T entries")
    B = np.zeros((T, T))
    for i in range(T):
        for j in range(i + 1):
            B[i, j] = x[i - j]
    return B


def general_td_jacobian(beta, f=None, g=None):
    """Build the T x T Jacobian of a general time-dependent pricing model.

    Parameters
    ----------
    beta : float
        Discount factor; ``f`` is reweighted by ``beta ** t`` to form the
        forward-looking kernel.
    f : array_like, optional
        Price age distribution, e.g. exponentially decaying for Calvo or
        uniform for Taylor. Defaults to a point mass at age 0 with the length
        of ``g``. Normalized internally to sum to one.
    g : array_like, optional
        Information age distribution, e.g. exponentially decaying for
        Mankiw-Reis. Defaults to a point mass at age 0 with the length of
        ``f``. Normalized internally to sum to one.

    Returns
    -------
    ndarray, shape (T, T)
        ``B @ diag(cumsum(g)) @ F``. ``B`` is the lower-triangular Toeplitz
        matrix of ``f``. ``F`` is the transposed lower-triangular Toeplitz
        matrix of the beta-discounted, renormalized ``f``.

    Raises
    ------
    TypeError
        If neither ``f`` nor ``g`` is supplied.
    ValueError
        If both are supplied with different lengths.
    """
    if f is None and g is None:
        raise TypeError("general_td_jacobian requires at least one of f or g")
    if f is not None:
        f = np.asarray(f, dtype=float)
    if g is not None:
        g = np.asarray(g, dtype=float)

    # The horizon comes from whichever distribution was supplied (f preferred).
    T = len(f) if f is not None else len(g)
    if f is None:
        f = np.zeros(T)
        f[0] = 1
    if g is None:
        g = np.zeros(T)
        g[0] = 1
    if len(g) != T:
        raise ValueError(
            f"f and g must have the same length (got {T} and {len(g)})"
        )

    f = f / np.sum(f)
    g = g / np.sum(g)
    G = np.cumsum(g)

    ft = f * beta ** np.arange(T)
    ft = ft / np.sum(ft)
    B = build_tri(f, T)
    F = build_tri(ft, T).T

    # B @ diag(G) scales column j of B by G[j]; broadcasting avoids building
    # the dense diagonal matrix.
    return (B * G) @ F


def hazards_to_price_age_distribution(hazards):
    """Convert age-specific reset hazards into a normalized price age distribution.

    The unnormalized weight for age ``t`` is ``prod_{s < t} (1 - hazards[s])``,
    with weight 1 at age 0. The last hazard therefore never affects the result.
    The weights are then scaled to sum to one.

    Parameters
    ----------
    hazards : array_like
        1-D sequence of hazards by age. Lists are accepted.

    Returns
    -------
    ndarray
        Distribution with the same length as ``hazards``. Empty if
        ``hazards`` is empty.
    """
    hazards = np.asarray(hazards, dtype=float)
    if hazards.size == 0:
        # Keep len(output) == len(input); a single leading 1 would be wrong here.
        return np.zeros(0)
    survival = np.concatenate([[1.0], np.cumprod(1 - hazards[:-1])])
    return survival / np.sum(survival)


def nominal_to_real(J_nom):
    """Map a nominal Jacobian to its real, first-differenced counterpart.

    Solves ``(I - J_nom) X = J_nom`` for ``X``. Returns ``X`` with its first
    row kept as-is and every later row replaced by its difference from the
    row above.
    """
    identity = np.eye(len(J_nom))
    levels = np.linalg.solve(identity - J_nom, J_nom)
    return np.concatenate([levels[:1], np.diff(levels, axis=0)], axis=0)


def calvo_jacobian(theta, beta, T):
    """Return the T x T Calvo Jacobian via ``general_td_jacobian``.

    The price age distribution is geometric: proportional to ``theta ** t``
    for ``t = 0, ..., T - 1`` and normalized to sum to one. No information
    age distribution is passed, so ``general_td_jacobian`` uses its default.
    """
    ages = np.arange(T)
    price_age = theta ** ages
    return general_td_jacobian(f=price_age / np.sum(price_age), beta=beta)


@njit
def calvo_PC(kappa, beta, T):
    """Return ``kappa`` times the upper-triangular Toeplitz matrix of beta powers.

    Entry ``(i, j)`` is ``kappa * beta ** (j - i)`` for ``j >= i``. Entries
    below the diagonal are zero.
    """
    discounts = beta ** np.arange(T)
    J = np.zeros((T, T))
    for j in range(T):
        for i in range(j + 1):
            J[i, j] = discounts[j - i]
    return kappa * J


def calvo_indexation_PC(kappa, wb, wf, T):
    """Return the T x T response of inflation to marginal cost under indexation.

    Stacks the Phillips curve
        pi_t = kappa * mc_t + wf * pi_{t+1} + wb * pi_{t-1}
    over t = 0, ..., T-1 and solves it for pi. Terms outside the window drop
    out because the lag matrix is truncated. The result is
    ``kappa * inv(I - wb * L - wf * L.T)``, where ``L`` has ones on the
    first subdiagonal.
    """
    lag = np.diag(np.ones(T - 1), -1)
    lead = lag.T
    system = np.eye(T) - wb * lag - wf * lead
    return kappa * np.linalg.inv(system)


def taylor_jacobian(frequency, beta, T):
    """Return the T x T Jacobian for Taylor (fixed-length) price contracts.

    The price age distribution is uniform over the first
    ``round(1 / frequency)`` ages and truncated at ``T``. Python's ``round``
    breaks ties to even. The information age distribution is left at the
    ``general_td_jacobian`` default.

    Raises
    ------
    ValueError
        If ``round(1 / frequency)`` is below 1. Previously this produced an
        all-zero distribution (NaNs after normalization) or, for negative
        values, a nonsensical one.
    """
    T_adj = round(1 / frequency)
    if T_adj < 1:
        raise ValueError(
            f"frequency={frequency!r} implies a contract length of {T_adj} "
            "periods; at least 1 is required"
        )
    f = np.zeros(T)
    f[:T_adj] = 1
    return general_td_jacobian(f=f, beta=beta)


def calvo_trend_inflation(theta, mu, sig, beta, eps, T):
    """Build Calvo Jacobians with trend inflation.

    Parameters (scalars assumed)
    ----------------------------
    theta : Calvo parameter scaling both the ``f`` and ``b`` kernels.
    mu, sig : enter through ``exp(k * mu + 0.5 * k**2 * sig**2)``; presumably
        the mean and std of log trend inflation (TODO confirm).
    beta : discount factor.
    eps : exponent in the growth terms; presumably the elasticity of
        substitution (TODO confirm).
    T : horizon.

    Returns
    -------
    M : (T, T) ndarray
        ``B @ F``, with ``B`` lower-triangular Toeplitz in ``b`` and ``F``
        upper-triangular Toeplitz in ``f``.
    J : (T, T) ndarray
        ``nominal_to_real(M)``.
    f, b : (T,) ndarrays
        The normalized forward and backward kernels.

    Raises
    ------
    ValueError
        If either ``1 - beta*theta*exp(...)`` term in ``x_star`` is not
        positive. ``x_star`` is the log-ratio of two geometric-series sums,
        which converge only when both terms are positive. Otherwise the
        result is NaN or, when both are negative, silently meaningless.
    """
    # Exponents of the two geometric growth terms (powers eps - 1 and eps).
    growth_lo = (eps - 1) * mu + 0.5 * ((eps - 1) ** 2) * (sig ** 2)
    growth_hi = eps * mu + 0.5 * (eps ** 2) * (sig ** 2)

    num = 1 - beta * theta * np.exp(growth_lo)
    den = 1 - beta * theta * np.exp(growth_hi)
    if not (num > 0 and den > 0):
        raise ValueError(
            "calvo_trend_inflation: beta*theta*exp(...) must be < 1 for both "
            f"growth terms (got 1 - ... = {num!r} and {den!r})"
        )
    x_star = np.log(num / den)

    t = np.arange(T)
    f = ((beta * theta) ** t) * (
        eps * np.exp(growth_hi * t)
        - (eps - 1) * np.exp(x_star + growth_lo * t)
    )
    f = f / np.sum(f)

    b = (theta * np.exp((eps - 1) * mu)) ** t
    b = b / np.sum(b)

    F = np.zeros((T, T))
    B = np.zeros((T, T))
    for s in range(T):
        F[s, s:] = f[:T - s]  # upper-triangular Toeplitz: F[i, j] = f[j - i]
        B[s:, s] = b[:T - s]  # lower-triangular Toeplitz: B[i, j] = b[i - j]

    M = B @ F
    J = nominal_to_real(M)

    return M, J, f, b

